import os

import nibabel as nib
import numpy as np


def removeClass(class_ids, file_path):
    """Set every voxel labelled with one of ``class_ids`` to 0 and overwrite the file.

    Args:
        class_ids: Iterable of label values to erase.
        file_path: Path to an image readable by ``nib.load``; it is rewritten
            in place as a NIfTI-1 image carrying the original affine and header.
    """
    # mmap=False: the file is overwritten below, so the voxel data must not
    # remain memory-mapped onto it (truncating a mapped file can crash the
    # process on POSIX systems and is refused on Windows).
    nii = nib.load(file_path, mmap=False)
    data = nii.get_fdata()
    # One vectorised pass over the volume instead of one pass per class id.
    data[np.isin(data, list(class_ids))] = 0

    # get_fdata() always yields float64. Saving float data under a header that
    # declares an integer dtype can make nibabel compute scl_slope/scl_inter
    # scaling, after which integer labels no longer read back exactly.
    # Cast back to the on-disk dtype when that is lossless; otherwise keep the
    # float data and let nibabel handle the conversion as before.
    out_data = data.astype(nii.get_data_dtype())
    if not np.array_equal(out_data, data):
        out_data = data
    out = nib.Nifti1Image(out_data, nii.affine, nii.header)
    nib.save(out, file_path)


def removeClassInDir(class_ids, dir_path, suffixes=('.nii.gz',)):
    """Apply ``removeClass`` to every matching image file directly inside ``dir_path``.

    The scan is not recursive. Entries that are not regular files (for example,
    a subdirectory whose name ends in ``.nii.gz``) are skipped.

    Args:
        class_ids: Iterable of label values to erase from each image.
        dir_path: Directory to scan.
        suffixes: Filename suffix, or tuple of suffixes, selecting which files
            to process. Defaults to ``('.nii.gz',)``; pass e.g.
            ``('.nii', '.nii.gz')`` to include uncompressed NIfTI files too.
    """
    # Materialise once: a one-shot iterator (e.g. a generator) would otherwise
    # be exhausted by the first file, and every later file left untouched.
    class_ids = tuple(class_ids)
    # os.listdir order is arbitrary; sort for a deterministic processing order.
    for file_name in sorted(os.listdir(dir_path)):
        file_path = os.path.join(dir_path, file_name)
        if file_name.endswith(suffixes) and os.path.isfile(file_path):
            removeClass(class_ids, file_path)
